[C/JAX] Comm+GEMM Overlap API for TE/JAX #1337

Draft
wants to merge 41 commits into
base: main
from

Conversation

denera
Collaborator

@denera denera commented Nov 15, 2024

Description

>>> Depends on PR #1307 <<<

This PR implements JAX/XLA custom ops and primitives for comm+GEMM overlap kernels in TE/common, and the pure Python/JAX infrastructure required to bootstrap the functionality.

Current limitations and considerations:

  • Requires a distributed launch with 1 process per GPU and execution with jax.distributed.initialize(). JAX does not have its own distributed launch utility like torchrun, so this is typically done with mpirun launch + mpi4py in Python.
  • TE has to be compiled with NVTE_UB_WITH_MPI=1 and Userbuffers has to be bootstrapped with MPI because XLA custom ops cannot execute XLA collectives. Unlike PyTorch, this does not introduce a new dependency because distributed launch with JAX already depends on MPI.
  • Userbuffers communication buffers are allocated outside of the XLA memory pool. Since XLA has no knowledge of these allocations, its memory allowance as a % of total device memory needs to be decreased to avoid OOM issues.
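The launch and memory considerations above can be sketched as follows. This is only an illustration under stated assumptions, not the bootstrap code from this PR: the helper names, the 0.5 fraction, and the coordinator address are hypothetical. `XLA_PYTHON_CLIENT_MEM_FRACTION` is the standard JAX environment variable for the preallocation fraction, and the keyword names match the documented arguments of `jax.distributed.initialize()`.

```python
import os


def reserve_memory_for_userbuffers(fraction: float = 0.5) -> str:
    """Lower XLA's GPU preallocation fraction to leave room for buffers
    allocated outside the XLA memory pool.

    Must be called before JAX initializes its GPU backend. The default of
    0.5 is an arbitrary illustrative value, not a recommendation.
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    value = f"{fraction:.2f}"
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = value
    return value


def distributed_init_kwargs(rank: int, world_size: int, local_rank: int,
                            coordinator: str = "localhost:12345") -> dict:
    """Build keyword arguments for jax.distributed.initialize() with one
    process per GPU. The coordinator address is a placeholder."""
    return {
        "coordinator_address": coordinator,
        "num_processes": world_size,
        "process_id": rank,
        "local_device_ids": [local_rank],
    }
```

Under an `mpirun` launch, `rank` and `world_size` would typically come from `MPI.COMM_WORLD.Get_rank()` and `MPI.COMM_WORLD.Get_size()` in mpi4py, and the result would be passed as `jax.distributed.initialize(**kwargs)` after the memory fraction has been set.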

To do:
[x] Implement XLA custom ops w/ both old API and new FFI interfaces.
[x] Extend JAX CollectiveGemmPrimitive to support comm+GEMM overlap.
[x] Implement bootstrapping and utility functions with PyBind11 bindings.
[x] Verify that comm+GEMM overlap extensions do not break non-overlap collective GEMM functionality.
[ ] Add new unit tests for comm+GEMM overlap.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@denera denera added enhancement New feature or request jax labels Nov 15, 2024
@denera denera self-assigned this Nov 15, 2024
@huanghua1994 huanghua1994 self-requested a review November 15, 2024 17:35
@denera denera force-pushed the jax-collective-gemm-with-overlap branch 2 times, most recently from 11ad5ec to e44d5cf Compare November 21, 2024 20:38
@denera denera force-pushed the jax-collective-gemm-with-overlap branch from c35b351 to 616e301 Compare December 3, 2024 14:08
denera and others added 19 commits December 5, 2024 21:33
Signed-off-by: Alp Dener <[email protected]>

Added XLA FFI custom op for TE GEMM

Signed-off-by: Alp Dener <[email protected]>

finished GEMM custom op primitive and serial unit test

Signed-off-by: Alp Dener <[email protected]>

fixed GEMM custom op batcher

Signed-off-by: Alp Dener <[email protected]>

fixed output dtype error and contracting dimensions options

Signed-off-by: Alp Dener <[email protected]>

AG overlap working but executes scatter to match outer LHS dim

Signed-off-by: Alp Dener <[email protected]>

both all-gather and all-reduce are now working

Signed-off-by: Alp Dener <[email protected]>

code style

Signed-off-by: Alp Dener <[email protected]>

changed kwargs in abstract to be explicit

Signed-off-by: Alp Dener <[email protected]>

added fwd/bwd implementation for non-fp8 gemm

Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Alp Dener <[email protected]>
denera and others added 9 commits December 5, 2024 21:34
…TE/JAX

Signed-off-by: Alp Dener <[email protected]>

comm+GEMM overlap API for TE/JAX compiles, untested, but did not break collective GEMM op

Signed-off-by: Alp Dener <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fixed static args

Signed-off-by: Alp Dener <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Alp Dener <[email protected]>
@denera denera force-pushed the jax-collective-gemm-with-overlap branch from 39f0375 to c4c608b Compare December 5, 2024 21:34

(out, out_amax, out_scale, pre_gelu_out, _, extra_out) = (  # bias_grad in non-FP8 GEMM
    CollectiveGemmPrimitive.outer_primitive.bind(
        rhs_t,

@denera the order of these ops needs to change.

    comm_overlap_config: Optional[dict] = None,
) -> Tuple[ArrayLike, ...]:
    """FP8 mat-mul with `nvte_cublas_gemm()` custom op."""
    out_shape_batched = (*lhs.shape[:-2], lhs.shape[-1], rhs_t.shape[-1])

The output shape looks wrong to me. It should be (*lhs.shape[:-2], lhs.shape[-2], rhs_t.shape[-2]).
@denera
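To make the difference concrete, here is a small shape-only sketch comparing the expression in the diff with the one suggested in this comment. It assumes `lhs` is laid out as (batch, M, K) and `rhs_t` as (N, K), that is, the RHS stored transposed with the contracting dimension last. That layout is an assumption for illustration and is not confirmed by the diff. The function names are hypothetical.

```python
def out_shape_as_written(lhs_shape, rhs_t_shape):
    # Expression from the diff: uses the last axis of both operands.
    return (*lhs_shape[:-2], lhs_shape[-1], rhs_t_shape[-1])


def out_shape_as_suggested(lhs_shape, rhs_t_shape):
    # Expression from the review comment: uses the second-to-last axis of both.
    return (*lhs_shape[:-2], lhs_shape[-2], rhs_t_shape[-2])


lhs_shape = (4, 16, 32)  # (batch, M, K), assumed layout
rhs_t_shape = (64, 32)   # (N, K), assumed layout of the transposed RHS

print(out_shape_as_written(lhs_shape, rhs_t_shape))    # (4, 32, 32)
print(out_shape_as_suggested(lhs_shape, rhs_t_shape))  # (4, 16, 64)
```

Under the assumed layout, only the suggested expression yields the (batch, M, N) result expected from the GEMM; the expression as written returns the contracting dimension K twice.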

        out_amax_updated_dtype == out_scale_updated_dtype == jnp.float32
    ), "Invalid output amax or scale dtype."
else:
    assert out_dtype == lhs_dtype, (

@denera this assertion is wrong for FP8 and needs to be guarded.
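One possible shape for such a guard is sketched below. It is not the fix applied in the PR. It assumes the problem case is an FP8 `lhs` producing a higher-precision output, so the equality check should only run when `lhs` is not FP8. Dtypes are represented as plain strings to keep the sketch self-contained, and `is_fp8_name` and `check_out_dtype` are hypothetical helpers.

```python
# Names of the FP8 dtypes considered here (illustrative subset).
FP8_DTYPE_NAMES = frozenset({"float8_e4m3fn", "float8_e5m2"})


def is_fp8_name(dtype_name: str) -> bool:
    """Return True if the dtype name refers to an FP8 type."""
    return dtype_name in FP8_DTYPE_NAMES


def check_out_dtype(lhs_dtype: str, out_dtype: str) -> None:
    """Require matching input/output dtypes only for non-FP8 inputs.

    With an FP8 LHS the output may legitimately use a different dtype
    (e.g. bfloat16), so the equality check is skipped in that case.
    """
    if not is_fp8_name(lhs_dtype):
        assert out_dtype == lhs_dtype, (
            f"Output dtype {out_dtype} does not match LHS dtype {lhs_dtype}."
        )


check_out_dtype("float8_e4m3fn", "bfloat16")  # passes: FP8 input is exempt
check_out_dtype("bfloat16", "bfloat16")       # passes: dtypes match
```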

if lhs_2d_shape is not None and lhs.ndim > 2:
    lhs = jax.lax.reshape(lhs, lhs_2d_shape, dimensions=lhs_layout)
if jax_dtype_is_fp8(lhs.dtype):
    lhs = jax.lax.transpose(lhs, (1, 0))

@denera do we need this transpose on the LHS? It seems wrong to me

@denera denera force-pushed the jax-collective-gemm-with-overlap branch from 8a63f8b to 5a3f4f3 Compare January 25, 2025 05:56
Labels
enhancement New feature or request jax
2 participants